{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the cheetah bot to run using PPO\n",
    "In this section, we learn how to train the 2D cheetah bot to run using proximal policy\n",
    "optimization (PPO). First, import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "from stable_baselines.common.policies import MlpPolicy\n",
    "from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize\n",
    "from stable_baselines import PPO2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a vectorized environment using `DummyVecEnv`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = DummyVecEnv([lambda: gym.make(\"HalfCheetah-v2\")])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Normalize the states (observations):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = VecNormalize(env, norm_obs=True)"
   ]
  },
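  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`VecNormalize` is understood to keep a running mean and variance of the observations and to\n",
    "rescale (and clip) each observation with them. The sketch below illustrates that idea in plain\n",
    "NumPy. The class `RunningNormalizer` and its parameters `epsilon` and `clip` are hypothetical\n",
    "names chosen for illustration, not the library's implementation:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class RunningNormalizer:\n",
    "    # Hypothetical helper for illustration; not part of stable_baselines.\n",
    "    def __init__(self, shape, epsilon=1e-8, clip=10.0):\n",
    "        self.mean = np.zeros(shape)\n",
    "        self.var = np.ones(shape)\n",
    "        self.count = 0\n",
    "        self.epsilon = epsilon  # avoids division by zero\n",
    "        self.clip = clip        # bound on the normalized values\n",
    "\n",
    "    def update(self, batch):\n",
    "        # Merge the batch mean and variance into the running statistics.\n",
    "        batch = np.asarray(batch, dtype=np.float64)\n",
    "        b_mean, b_var, b_count = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]\n",
    "        delta = b_mean - self.mean\n",
    "        total = self.count + b_count\n",
    "        new_mean = self.mean + delta * b_count / total\n",
    "        m2 = (self.var * self.count + b_var * b_count\n",
    "              + delta**2 * self.count * b_count / total)\n",
    "        self.mean, self.var, self.count = new_mean, m2 / total, total\n",
    "\n",
    "    def normalize(self, obs):\n",
    "        # Center, scale to unit variance, then clip extreme values.\n",
    "        scaled = (obs - self.mean) / np.sqrt(self.var + self.epsilon)\n",
    "        return np.clip(scaled, -self.clip, self.clip)\n",
    "\n",
    "\n",
    "# Made-up observations with 4 features, fed in two batches\n",
    "rng = np.random.RandomState(0)\n",
    "data = rng.normal(3.0, 2.0, size=(200, 4))\n",
    "normalizer = RunningNormalizer(shape=(4,))\n",
    "normalizer.update(data[:50])\n",
    "normalizer.update(data[50:])\n",
    "normalized = normalizer.normalize(data)\n",
    "```\n",
    "\n",
    "Feeding the observations batch by batch gives the same statistics as computing them over all\n",
    "the data at once, which is why the normalization can be updated while the agent is still\n",
    "collecting experience."
   ]
  },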
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = PPO2(MlpPolicy, env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the agent for 250,000 time steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.learn(total_timesteps=250000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also watch how the trained agent performs in the environment:"
   ]
  },
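  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = env.reset()\n",
    "# Runs until interrupted (Ctrl+C); the vectorized env resets itself when an episode ends\n",
    "while True:\n",
    "    # predict returns the action and a recurrent state, which is unused here\n",
    "    action, _ = agent.predict(state)\n",
    "    next_state, reward, done, info = env.step(action)\n",
    "    state = next_state\n",
    "    env.render()"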
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save all the code used in this section in a Python file called `ppo.py`, then open a\n",
    "terminal and run the file:\n",
    "\n",
    "`python ppo.py`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Making a GIF of a trained agent\n",
    "\n",
    "In the previous section, we learned how to train the cheetah bot to run using PPO. Can we\n",
    "also create a GIF file of our trained agent? Yes! Let's see how to do that.\n",
    "\n",
    "First, import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import imageio\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the list for storing images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the state by resetting the environment. Here, `agent` is the agent we trained in\n",
    "the previous section:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = agent.env.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Render the environment and get the image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = agent.env.render(mode='rgb_array')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the agent for 500 steps and save the rendered image at every step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(500):\n",
    "    # Store the current frame, then advance the environment by one step\n",
    "    images.append(img)\n",
    "    action, _ = agent.predict(state)\n",
    "    next_state, reward, done, info = agent.env.step(action)\n",
    "    state = next_state\n",
    "    img = agent.env.render(mode='rgb_array')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the GIF file from every second frame, at 29 frames per second:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every second frame to reduce the size of the GIF\n",
    "frames = [np.array(img) for img in images[::2]]\n",
    "imageio.mimsave('HalfCheetah.gif', frames, fps=29)"
   ]
  },
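  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on the result, the arithmetic below works out how many frames end up in the\n",
    "GIF and how long it plays, assuming the 500 collected frames, every second frame kept, and\n",
    "playback at the requested 29 frames per second. The helper `gif_summary` is a hypothetical\n",
    "name used only for this illustration:\n",
    "\n",
    "```python\n",
    "def gif_summary(total_frames, keep_every, fps):\n",
    "    # Number of frames kept when taking indices 0, keep_every, 2 * keep_every, ...\n",
    "    kept = len(range(0, total_frames, keep_every))\n",
    "    # Playback time in seconds at the given frame rate\n",
    "    return kept, kept / fps\n",
    "\n",
    "\n",
    "frames_kept, seconds = gif_summary(total_frames=500, keep_every=2, fps=29)\n",
    "print(frames_kept, round(seconds, 2))  # prints: 250 8.62\n",
    "```\n",
    "\n",
    "Keeping fewer frames shrinks the file but makes the motion look faster at the same frame\n",
    "rate, so the subsampling step and `fps` are best adjusted together."
   ]
  },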
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This creates a new file called `HalfCheetah.gif` in the current working directory."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
